Pass input_output_alias to TritonAutotunedKernelCall#2814
Pass input_output_alias to TritonAutotunedKernelCall#2814tdophung wants to merge 3 commits intoNVIDIA:mainfrom
Conversation
Signed-off-by: JAX Toolbox <jax@nvidia.com>
Greptile SummaryThis PR removes a workaround (WAR) that previously suppressed The fix in Key changes:
Confidence Score: 5/5Safe to merge — the WAR is preserved for all currently-released JAX versions, and the new path is correctly gated behind the upstream fix version. No P0 or P1 issues found. The alias tuple construction (input_idx, num_inputs + output_idx, size_bytes) is correct. The version gate correctly keeps the old WAR active on JAX < 0.9.3, so there is no regression risk on currently available JAX releases. The UserWarning is informative and correctly stacked. The upstream fix reference (jax-ml/jax#35218) is well-documented. No files require special attention. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["triton_call_lowering() called\nwith input_output_aliases"] --> B{is_autotuned?}
B -- No --> C["TritonKernelCall\n(aliases via ffi_lowering only)"]
B -- Yes --> D{input_output_aliases\nis truthy?}
D -- No --> E["input_output_aliases_with_sizes = ()"]
D -- Yes --> F{JAX >=\n0.9.3?}
F -- Yes --> G["Build alias tuples:\n(input_idx,\n num_inputs + output_idx,\n size_bytes)"]
G --> H["TritonAutotunedKernelCall\nwith aliases ✓"]
F -- No --> I["UserWarning emitted\ninput_output_aliases_with_sizes = () WAR"]
I --> J["TritonAutotunedKernelCall\nwith empty aliases (safe WAR)"]
E --> J
Reviews (3): Last reviewed commit: "Add jax version guard for the input_outp..." | Re-trigger Greptile |
for more information, see https://pre-commit.ci
Signed-off-by: tdophung <tdophung@nvidia.com>
jberchtold-nvidia
left a comment
There was a problem hiding this comment.
LGTM pending CI, thanks! Good idea to consolidate all the Triton+JAX version requirements in a single place!
|
/te-ci jax |
Description
https://nvbugspro.nvidia.com/bug/5810384
To remove the WAR that was put in place for this bug.
This should also serves as part 2 to WAR to the intermittent sort_chunks_by_index bug seen before in #2730
Fixes # (issue)
Type of change
Changes
if JAX version >= 0.9.3, which contains the fix to restore all aliased input buffers that was saved away during autotuning: we pass the input_output_alias tuples to TritonAutotunedKernelCall
If JAX version < 0.9.3, which does not contain the fix, we pass an empty dict to the call.
Checklist: